/* Copyright (c) 2023, Arm Limited and Contributors. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 */
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

/* Give the log a name when the callback is defined to PW_LOG_* (as it is in the example). */
#ifdef PW_LOG_MODULE_NAME
#undef PW_LOG_MODULE_NAME
#endif /* PW_LOG_MODULE_NAME */
#define PW_LOG_MODULE_NAME "asan"

#include <iotsdk_sanitizers/asan.h>

/* Redzone defaults to 8 bytes either side of the user-requested region. */
#ifndef IOTSDK_ASAN_REDZONE_SIZE
#define IOTSDK_ASAN_REDZONE_SIZE 8
#endif /* ! IOTSDK_ASAN_REDZONE_SIZE */

#if IOTSDK_ASAN_REDZONE_SIZE < 4 || IOTSDK_ASAN_REDZONE_SIZE % 4 != 0
#error IOTSDK_ASAN_REDZONE_SIZE size should be >=4 and a multiple of 4
#endif

/* Quarantine list defaults to last 16 freed blocks. */
#ifndef IOTSDK_ASAN_QUARANTINE_SIZE
#define IOTSDK_ASAN_QUARANTINE_SIZE 16
#endif /* ! IOTSDK_ASAN_QUARANTINE_SIZE */

/* Ensure shadow arena's bounds are defined. */
#if !(defined(IOTSDK_ASAN_SHADOW_START) && defined(IOTSDK_ASAN_SHADOW_END))
#error IOTSDK_ASAN_SHADOW_START and/or IOTSDK_ASAN_SHADOW_END undefined
#endif /* !(IOTSDK_ASAN_SHADOW_START && IOTSDK_ASAN_SHADOW_END) */

/* Ensure heap bounds are defined. */
#if !(defined(IOTSDK_ASAN_SYM_HEAP_START) && defined(IOTSDK_ASAN_SYM_HEAP_END))
#error IOTSDK_ASAN_SYM_HEAP_START and/or IOTSDK_ASAN_SYM_HEAP_END undefined
#endif /* !(IOTSDK_ASAN_SYM_HEAP_START && IOTSDK_ASAN_SYM_HEAP_END) */

/* Ensure stack bounds are defined. */
#if !(defined(IOTSDK_ASAN_SYM_STACK_START) && defined(IOTSDK_ASAN_SYM_STACK_END))
#error IOTSDK_ASAN_SYM_STACK_START and/or IOTSDK_ASAN_SYM_STACK_END undefined
#endif /* !(IOTSDK_ASAN_SYM_STACK_START && IOTSDK_ASAN_SYM_STACK_END) */

/* Get caller return address from the stack (for error reporting). */
#define GET_CALLER_ADDR() (uintptr_t) __builtin_extract_return_addr(__builtin_return_address(0))

/* Indicate whether P satisfies alignment requirement ALIGN e.g. ALIGN=4 for 4 bytes requirement. */
#define IS_ALIGNED(P, ALIGN) ((((uintptr_t)(P)) & (((uintptr_t)(ALIGN)) - 1)) == 0)

/* Align P leftwards to the nearest alignment boundary ALIGN, e.g. ALIGN_LEFT(0xFFF4, 8) -> 0xFFF0. */
#define ALIGN_LEFT(P, ALIGN) ((P) - (((uintptr_t)(P)) & ((ALIGN)-1)))

/* Indicate if X is inside [A,B). */
#define BETWEEN(X, A, B) (((void *)(A)) <= ((void *)(X)) && ((void *)(X)) < ((void *)(B)))

/* Size of heap, stack & shadow regions. */
#define HEAP_SIZE   ((size_t)(IOTSDK_ASAN_SYM_HEAP_END - IOTSDK_ASAN_SYM_HEAP_START))
#define STACK_SIZE  ((size_t)(IOTSDK_ASAN_SYM_STACK_END - IOTSDK_ASAN_SYM_STACK_START))
#define SHADOW_SIZE ((size_t)(IOTSDK_ASAN_SHADOW_END - IOTSDK_ASAN_SHADOW_START))

/* Heap, stack & shadow boundary linker symbols. */
extern uint8_t IOTSDK_ASAN_SYM_HEAP_START[];
extern uint8_t IOTSDK_ASAN_SYM_HEAP_END[];
extern uint8_t IOTSDK_ASAN_SYM_STACK_START[];
extern uint8_t IOTSDK_ASAN_SYM_STACK_END[];
extern uint8_t IOTSDK_ASAN_SHADOW_START[];
extern uint8_t IOTSDK_ASAN_SHADOW_END[];

enum { READ_ACCESS = 1, WRITE_ACCESS = 2, SHADOW_BLKSIZE = 8 };

/* Classification of an address: outside any tracked region, inside the heap, or
 * inside the stack. Values double as indices into region_type_to_string below. */
typedef enum {
    NOT_PRESENT,
    HEAP_REGION,
    STACK_REGION,
} region_type_t;

/* Human-readable names for region_type_t, used in error reports; order must match
 * the enumerator order above. */
static const char *region_type_to_string[] = {"not present", "heap", "stack"};

/* Protected memory region info. */
struct mem_region {
    unsigned type : 3;   /* A region_type_t value (3 bits covers all enumerators). */
    const void *start;   /* First byte of the tracked region. */
    const void *end;     /* One past the last byte of the tracked region. */
    void *shadow_start;  /* First shadow byte tracking this region. */
    size_t shadow_size;  /* Number of shadow bytes covering the region. */
};

/* If p is within the stack or heap, fill region (if not NULL) with information and return true. Otherwise set region
 * type (if region is not NULL) to NOT_PRESENT and return false.
 */
/* If p is within the stack or heap, fill region (if not NULL) with information and return true. Otherwise set region
 * type (if region is not NULL) to NOT_PRESENT and return false.
 *
 * Shadow layout: the heap's shadow occupies the start of the shadow arena and the
 * stack's shadow follows immediately after it. One shadow byte tracks SHADOW_BLKSIZE
 * (8) bytes of memory, so a region of S bytes needs ceil(S / SHADOW_BLKSIZE) shadow
 * bytes.
 */
static bool get_mem_region(const void *p, struct mem_region *region)
{
    /* ceil(S / SHADOW_BLKSIZE). The previous formula, `S / 8 + S % 8`, over-reserved
     * by up to 6 shadow bytes per region (e.g. S=15 gave 8 instead of 2), wasting
     * shadow arena space and shifting the stack shadow further than necessary. */
    const size_t heap_shadow_size = HEAP_SIZE / SHADOW_BLKSIZE + (HEAP_SIZE % SHADOW_BLKSIZE != 0);
    const size_t stack_shadow_size = STACK_SIZE / SHADOW_BLKSIZE + (STACK_SIZE % SHADOW_BLKSIZE != 0);

    if (BETWEEN(p, IOTSDK_ASAN_SYM_HEAP_START, IOTSDK_ASAN_SYM_HEAP_END)) {
        if (region) {
            region->type = HEAP_REGION;
            region->start = IOTSDK_ASAN_SYM_HEAP_START;
            region->end = IOTSDK_ASAN_SYM_HEAP_END;
            region->shadow_start = IOTSDK_ASAN_SHADOW_START;
            region->shadow_size = heap_shadow_size;
        }

        return true;
    }

    if (BETWEEN(p, IOTSDK_ASAN_SYM_STACK_START, IOTSDK_ASAN_SYM_STACK_END)) {
        if (region) {
            region->type = STACK_REGION;
            region->start = IOTSDK_ASAN_SYM_STACK_START;
            region->end = IOTSDK_ASAN_SYM_STACK_END;
            /* Stack shadow starts right after the heap shadow. */
            region->shadow_start = IOTSDK_ASAN_SHADOW_START + heap_shadow_size;
            region->shadow_size = stack_shadow_size;
        }

        return true;
    }

    if (region) {
        region->type = NOT_PRESENT;
    }

    return false;
}

/* Quarantine list, a list of recently-freed pointers.
 * Each entry records the [start, end) span of the underlying freed block so
 * double frees and recently-freed accesses can be attributed; a NULL start
 * marks an unused slot.
 */
static struct quarantine {
    void *start;
    void *end;
} quarantine_list[IOTSDK_ASAN_QUARANTINE_SIZE];

/* Next slot of quarantine_list to overwrite; wraps to 0 (ring-buffer style). */
static size_t quarantine_list_idx = 0;

static struct quarantine *get_quarantine(const void *p)
{
    for (int i = 0; i < IOTSDK_ASAN_QUARANTINE_SIZE; i++) {
        if (quarantine_list[i].start == NULL) {
            continue;
        }

        if (BETWEEN(p, quarantine_list[i].start, quarantine_list[i].end)) {
            return &quarantine_list[i];
        }
    }

    return NULL;
}

/* Report a double free of p through the user-supplied sanitizer error callback.
 * `caller` is the PC of the code that performed the offending (second) free. */
static void report_double_free(const void *p, uintptr_t caller)
{
    iotsdk_sanitizers_error_cb(
        "AddressSanitizer error: Double free of pointer 0x%08X [pc=0x%08X]", (uintptr_t)p, caller);
}

/* Report an invalid read/write of `size` bytes at p through the user-supplied
 * sanitizer error callback. `rw` is READ_ACCESS or WRITE_ACCESS; `caller` is the
 * PC of the offending access. p must lie inside a tracked region.
 */
static void report_bad_access(const void *p, size_t size, int rw, uintptr_t caller)
{
    struct mem_region region;
    bool is_valid_region = get_mem_region(p, &region);
    assert(is_valid_region);
    (void)is_valid_region; /* Silence -Wunused-variable when asserts are compiled out. */

    bool is_read = rw == READ_ACCESS;
    bool is_quarantined = get_quarantine(p) != NULL;
    iotsdk_sanitizers_error_cb(
        "AddressSanitizer error: Invalid %s of %zu bytes at 0x%08X, %zu bytes inside %s%s [pc=0x%08X]",
        is_read ? "read" : "write",
        size,
        (uintptr_t)p,
        /* Offset into the region, computed via uintptr_t: arithmetic on
         * `const void *` is a GNU extension, not standard C. */
        (size_t)((uintptr_t)p - (uintptr_t)region.start),
        region_type_to_string[region.type],
        is_quarantined ? " and recently freed" : "",
        caller);
}

/* Compute shadow address for pointer 'p' which is in region 'region'.*/
/* Map a pointer inside `region` to the shadow byte tracking it: each shadow byte
 * covers SHADOW_BLKSIZE bytes of the region, starting at region->shadow_start. */
static uint8_t *mem_to_shadow(const struct mem_region *region, const void *p)
{
    uintptr_t offset_in_region = (uintptr_t)p - (uintptr_t)region->start;
    uint8_t *shadow_base = (uint8_t *)region->shadow_start;
    return shadow_base + offset_in_region / SHADOW_BLKSIZE;
}

#define TURN_BIT_ON(X, POS)  ((X) |= 1 << (POS))
#define TURN_BIT_OFF(X, POS) ((X) &= ~(1 << (POS)))

/* Mark `count` bytes starting at p as poisoned (inaccessible) by setting one
 * shadow bit per byte. p must lie inside `region` and be int-aligned.
 */
static void poison_bytes(const struct mem_region *region, const void *p, size_t count)
{
    assert(region != NULL);
    assert(region->type != NOT_PRESENT && BETWEEN(p, region->start, region->end));
    assert(p != NULL);
    assert(IS_ALIGNED(p, sizeof(int)));
    assert(count > 0);

    uint8_t *shadow = mem_to_shadow(region, p);
    /* Bit offset of p within its shadow block (shadow bits are indexed from the
     * SHADOW_BLKSIZE-aligned block base, not from p itself). */
    size_t start = (uintptr_t)p % SHADOW_BLKSIZE;

    /* Set one bit per poisoned byte. The loop must cover [start, start + count):
     * the previous upper bound of `count` silently dropped the last `start` bytes
     * whenever p was not SHADOW_BLKSIZE-aligned.
     */
    for (size_t i = start; i < start + count; i++) {
        size_t shadow_byte = i / SHADOW_BLKSIZE;
        size_t shadow_byte_bit = i % SHADOW_BLKSIZE;
        TURN_BIT_ON(shadow[shadow_byte], shadow_byte_bit);
    }
}

/* Mark `count` bytes starting at p as accessible by clearing one shadow bit per
 * byte. p must lie inside `region`.
 */
static void unpoison_bytes(const struct mem_region *region, const void *p, size_t count)
{
    assert(region != NULL);
    assert(region->type != NOT_PRESENT && BETWEEN(p, region->start, region->end));
    assert(p != NULL);
    assert(count > 0);

    uint8_t *shadow = mem_to_shadow(region, p);
    /* Bit offset of p within its shadow block (shadow bits are indexed from the
     * SHADOW_BLKSIZE-aligned block base, not from p itself). */
    size_t start = (uintptr_t)p % SHADOW_BLKSIZE;

    /* Clear one bit per unpoisoned byte. The loop must cover [start, start + count):
     * the previous upper bound of `count` left the last `start` bytes poisoned
     * whenever p was not SHADOW_BLKSIZE-aligned — the malloc wrapper's user pointer
     * is offset by sizeof(size_t) + redzone, so this caused false positives on the
     * tail of every such allocation.
     */
    for (size_t i = start; i < start + count; i++) {
        size_t shadow_byte = i / SHADOW_BLKSIZE;
        size_t shadow_byte_bit = i % SHADOW_BLKSIZE;
        TURN_BIT_OFF(shadow[shadow_byte], shadow_byte_bit);
    }
}

/* Validate an access of 1, 2, 4 or 8 bytes at p against the shadow memory and
 * report it via report_bad_access() if any accessed byte is poisoned. Accesses
 * outside the tracked heap/stack regions are ignored.
 */
static void check_shadow_mem(const void *p, size_t size, int rw, uintptr_t caller)
{
    assert(size == 1 || size == 2 || size == 4 || size == 8);

    /* Find memory region; untracked addresses are not checked.
     */
    struct mem_region region;
    if (!get_mem_region(p, &region)) {
        return;
    }

    /* Gather the shadow bits covering bytes [p, p + size). Each shadow byte holds
     * one poison bit per tracked byte. An unaligned access can straddle two shadow
     * blocks, so widen to 16 bits and pull in the following shadow byte when
     * needed: the previous code masked within a single uint8_t, silently ignoring
     * the part of the access that spilled into the next block.
     */
    const uint8_t *shadow = mem_to_shadow(&region, p);
    size_t start = (uintptr_t)p % SHADOW_BLKSIZE; /* Bit offset of p in its block. */
    uint16_t bits = shadow[0];
    if (start + size > SHADOW_BLKSIZE) {
        /* NOTE(review): if the access itself runs past the region end, shadow[1]
         * may read one byte beyond this region's shadow — confirm the shadow
         * arena always has a trailing byte available in that case. */
        bits |= (uint16_t)((uint16_t)shadow[1] << SHADOW_BLKSIZE);
    }

    /* Mask with `size` bits set, positioned at the accessed bytes. */
    uint16_t mask = (uint16_t)((((uint32_t)1 << size) - 1) << start);
    if ((bits & mask) != 0) {
        report_bad_access(p, size, rw, caller);
    }
}

/* Runs before main (constructor): poison the entire heap and stack shadow so
 * memory only becomes accessible once explicitly unpoisoned (e.g. by the malloc
 * wrapper).
 */
__attribute__((constructor)) static void iotsdk_asan_init(void)
{
    /* Look up heap & stack region descriptors (shadow placement and sizes).
     */
    struct mem_region heap, stack;
    {
        bool is_valid_region = get_mem_region(IOTSDK_ASAN_SYM_HEAP_START, &heap);
        assert(is_valid_region);
        (void)is_valid_region; /* Silence -Wunused-variable when asserts are compiled out. */
    }
    {
        bool is_valid_region = get_mem_region(IOTSDK_ASAN_SYM_STACK_START, &stack);
        assert(is_valid_region);
        (void)is_valid_region;
    }

    /* Ensure linker-defined shadow area is large enough.
     */
    assert(SHADOW_SIZE >= heap.shadow_size + stack.shadow_size);

    /* 0xFF in every shadow byte marks all tracked bytes as poisoned. */
    memset(heap.shadow_start, -1, heap.shadow_size);
    memset(stack.shadow_start, -1, stack.shadow_size);
}

/*******************************************************************************
 * ASAN callbacks
 ******************************************************************************/

extern void *IOTSDK_ASAN_REAL_MALLOC(size_t);
extern void IOTSDK_ASAN_REAL_FREE(void *);

/* malloc wrapper: allocate `size` user bytes surrounded by poisoned red-zones.
 *
 * Layout of the underlying allocation:
 *   [size_t size][redzone][user memory (size bytes)][redzone]
 *
 * Returns a pointer to the user memory, or NULL on allocation failure / size
 * overflow.
 */
void *IOTSDK_ASAN_WRAP_MALLOC(size_t size)
{
    size_t overhead = sizeof(size_t) + 2 * IOTSDK_ASAN_REDZONE_SIZE;

    /* Reject sizes whose padded total would wrap around. */
    if (size > SIZE_MAX - overhead) {
        return NULL;
    }

    void *whole_block = IOTSDK_ASAN_REAL_MALLOC(size + overhead);
    if (whole_block == NULL) {
        /* Propagate allocation failure; the previous code went on to offset and
         * poison a NULL block. */
        return NULL;
    }

    void *user_memory = (uint8_t *)whole_block + sizeof(size_t) + IOTSDK_ASAN_REDZONE_SIZE;

    struct mem_region region;
    bool is_valid_region = get_mem_region(whole_block, &region);
    assert(is_valid_region);
    (void)is_valid_region; /* Silence -Wunused-variable when asserts are compiled out. */

    /* The allocator may reuse a recently-freed block: drop any stale quarantine
     * entry so a future free of this pointer is not misreported as a double free.
     */
    struct quarantine *quarantine = get_quarantine(whole_block);
    if (quarantine) {
        memset(quarantine, 0, sizeof(*quarantine));
    }

    /* Store the user-requested size before the first red-zone so free() can
     * recover it.
     */
    *((size_t *)whole_block) = size;

    /* Poison the entire block (size header + red-zones + user area), then unpoison
     * just the user area. Skip the unpoison for malloc(0): there are no user bytes
     * and unpoison_bytes() asserts count > 0.
     */
    poison_bytes(&region, whole_block, size + overhead);
    if (size > 0) {
        unpoison_bytes(&region, user_memory, size);
    }

    return user_memory;
}

/* free wrapper: re-poison the user area, release the underlying block and record
 * it in the quarantine ring so double frees and recently-freed accesses can be
 * reported. free(NULL) and pointers outside the tracked heap are no-ops.
 */
void IOTSDK_ASAN_WRAP_FREE(void *p)
{
    struct mem_region region;
    bool is_valid_region = get_mem_region(p, &region);
    if (!is_valid_region) {
        return;
    }

    /* A pointer still in quarantine has already been freed once. */
    struct quarantine *quarantine = get_quarantine(p);
    if (quarantine) {
        report_double_free(p, GET_CALLER_ADDR());
        return;
    }

    /* Recover the underlying block pointer and the user-requested size stored by
     * the malloc wrapper.
     */
    void *user_memory = p;
    void *whole_block = (uint8_t *)user_memory - IOTSDK_ASAN_REDZONE_SIZE - sizeof(size_t);
    size_t size = *((size_t *)whole_block);

    /* Re-poison the user area so use-after-free is caught. Skip for malloc(0)
     * blocks: there are no user bytes and poison_bytes() asserts count > 0 —
     * the previous code tripped that assert when freeing a zero-size allocation.
     */
    if (size > 0) {
        poison_bytes(&region, user_memory, size);
    }
    IOTSDK_ASAN_REAL_FREE(whole_block);

    /* Add to quarantine list, restarting from 0 if it overflows.
     */
    if (quarantine_list_idx == IOTSDK_ASAN_QUARANTINE_SIZE) {
        quarantine_list_idx = 0;
    }

    quarantine_list[quarantine_list_idx].start = whole_block;
    quarantine_list[quarantine_list_idx].end =
        (uint8_t *)whole_block + 2 * IOTSDK_ASAN_REDZONE_SIZE + sizeof(size_t) + size;
    quarantine_list_idx++;
}

/* Compiler-emitted hook: validate an 8-byte read at p. The caller PC must be
 * captured in this frame — __builtin_return_address(0) is frame-sensitive. */
void __asan_load8_noabort(void *p)
{
    check_shadow_mem(p, 8, READ_ACCESS, GET_CALLER_ADDR());
}

/* Compiler-emitted hook: validate an 8-byte write at p. The caller PC must be
 * captured in this frame — __builtin_return_address(0) is frame-sensitive. */
void __asan_store8_noabort(void *p)
{
    check_shadow_mem(p, 8, WRITE_ACCESS, GET_CALLER_ADDR());
}

/* Compiler-emitted hook: validate a 4-byte read at p. */
void __asan_load4_noabort(void *p)
{
    check_shadow_mem(p, 4, READ_ACCESS, GET_CALLER_ADDR());
}

/* Compiler-emitted hook: validate a 4-byte write at p. */
void __asan_store4_noabort(void *p)
{
    check_shadow_mem(p, 4, WRITE_ACCESS, GET_CALLER_ADDR());
}

/* Compiler-emitted hook: validate a 2-byte read at p. */
void __asan_load2_noabort(void *p)
{
    check_shadow_mem(p, 2, READ_ACCESS, GET_CALLER_ADDR());
}

/* Compiler-emitted hook: validate a 2-byte write at p. */
void __asan_store2_noabort(void *p)
{
    check_shadow_mem(p, 2, WRITE_ACCESS, GET_CALLER_ADDR());
}

/* Compiler-emitted hook: validate a 1-byte read at p. */
void __asan_load1_noabort(void *p)
{
    check_shadow_mem(p, 1, READ_ACCESS, GET_CALLER_ADDR());
}

/* Compiler-emitted hook: validate a 1-byte write at p. */
void __asan_store1_noabort(void *p)
{
    check_shadow_mem(p, 1, WRITE_ACCESS, GET_CALLER_ADDR());
}

/* Note: the remaining callbacks are not documented, so they are left unimplemented. */

/* Fake-stack allocation hook (stack-use-after-return support); intentionally a
 * no-op here. Parameters are consumed to avoid -Wunused-parameter warnings. */
void __asan_stack_malloc_1(size_t size, void *addr)
{
    (void)size;
    (void)addr;
}

/* Fake-stack allocation hook (stack-use-after-return support); intentionally a
 * no-op here. Parameters are consumed to avoid -Wunused-parameter warnings. */
void __asan_stack_malloc_2(size_t size, void *addr)
{
    (void)size;
    (void)addr;
}

/* Fake-stack allocation hook (stack-use-after-return support); intentionally a
 * no-op here. Parameters are consumed to avoid -Wunused-parameter warnings. */
void __asan_stack_malloc_3(size_t size, void *addr)
{
    (void)size;
    (void)addr;
}

/* Fake-stack allocation hook (stack-use-after-return support); intentionally a
 * no-op here. Parameters are consumed to avoid -Wunused-parameter warnings. */
void __asan_stack_malloc_4(size_t size, void *addr)
{
    (void)size;
    (void)addr;
}

/* Fake-stack allocation hook (stack-use-after-return support); intentionally a
 * no-op here. Parameters are consumed to avoid -Wunused-parameter warnings. */
void __asan_stack_malloc_8(size_t size, void *addr)
{
    (void)size;
    (void)addr;
}

/* Emitted by the compiler before calls that never return; intentionally a no-op
 * in this port (undocumented ABI hook). */
void __asan_handle_no_return(void)
{
}

/* Intentionally a no-op. NOTE(review): in the upstream ASAN runtime this symbol
 * is a global int flag rather than a function — confirm the generated code only
 * references it as an external symbol, not as data. */
void __asan_option_detect_stack_use_after_return(void)
{
}

/* Global-variable registration hook; intentionally a no-op (globals are not
 * tracked). NOTE(review): upstream ASAN declares this with parameters
 * (globals array + count) — verify the prototype matches the emitted calls. */
void __asan_register_globals(void)
{
}

/* Global-variable unregistration hook; intentionally a no-op (globals are not
 * tracked). NOTE(review): upstream ASAN declares this with parameters — verify
 * the prototype matches the emitted calls. */
void __asan_unregister_globals(void)
{
}

/* Version handshake symbol referenced by instrumented objects; a no-op body is
 * sufficient to satisfy the link. */
void __asan_version_mismatch_check_v8(void)
{
}
